// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "bucketselector.h"
#include "bucketidfactory.h"
#include <vespa/document/base/documentid.h>
#include <vespa/document/select/node.h>
#include <vespa/document/select/valuenodes.h>
#include <vespa/document/select/visitor.h>
#include <vespa/document/select/branch.h>
#include <vespa/document/select/compare.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <algorithm>

namespace document {

using namespace document::select;

//namespace {
    /**
     * Visitor class that is used for visiting a node tree generated by a
     * document selection expression.
     *
     * The visitor class contains the set of buckets expression can match.
     */
    struct BucketVisitor : public document::select::Visitor {
        const BucketIdFactory& _factory;
        std::vector<BucketId> _buckets;
            // If set to false, the buckets to visit is set in _buckets.
        bool _unknown;

        BucketVisitor(const BucketIdFactory& factory)
            : _factory(factory), _buckets(), _unknown(true) {}

        void visitAndBranch(const document::select::And& node) override {
            BucketVisitor left(_factory);
            node.getLeft().visit(left);
            node.getRight().visit(*this);
                // If either part is unknown we can just return other part.
            if (left._unknown) {
                return;
            }
                // If only left part is known return that part.
            if (_unknown) {
                _buckets.swap(left._buckets);
                _unknown = false;
                return;
            }
            std::vector<BucketId> result;
            std::set_intersection(left._buckets.begin(), left._buckets.end(),
                                  _buckets.begin(), _buckets.end(),
                                  back_inserter(result));
            _buckets.swap(result);
            return;
        }

        void visitOrBranch(const document::select::Or& node) override {
            BucketVisitor left(_factory);
            node.getLeft().visit(left);
            node.getRight().visit(*this);
                // If one part is unknown we have to keep unknown status
            if (left._unknown || _unknown) {
                _unknown = true;
                return;
            }
                // If both parts are known return both sets
            std::vector<BucketId> result;
            std::set_union(left._buckets.begin(), left._buckets.end(),
                           _buckets.begin(), _buckets.end(),
                           back_inserter(result));
            _buckets.swap(result);
        }

        void visitNotBranch(const document::select::Not&) override {
            // Since selected locations doesn't include everything at that
            // location, we can't reverse the selection. Any NOT branch must
            // end up specifying all
        }

        void compare(const select::IdValueNode& node,
                     const select::ValueNode& valnode,
                     const select::Operator& op)
        {
            if (node.getType() == IdValueNode::ALL) {
                const StringValueNode* val(
                        dynamic_cast<const StringValueNode*>(&valnode));
                if (!val) return;
                vespalib::string docId(val->getValue());
                if (op == FunctionOperator::EQ ||
                    !GlobOperator::containsVariables(docId))
                {
                    IdString::UP id(IdString::createIdString(docId));
                    _buckets.push_back(BucketId(58, id->getLocation()));
                    _unknown = false;
                }
            } else if (node.getType() == IdValueNode::USER) {
                const IntegerValueNode* val(
                        dynamic_cast<const IntegerValueNode*>(&valnode));
                if (!val) return;
                UserDocIdString id(vespalib::make_string("userdoc::%lu:", val->getValue()));
                _buckets.push_back(BucketId(32, id.getLocation()));
                _unknown = false;
            } else if (node.getType() == IdValueNode::GROUP) {
                const StringValueNode* val(
                        dynamic_cast<const StringValueNode*>(&valnode));
                if (!val) return;
                vespalib::string group(val->getValue());
                if (op == FunctionOperator::EQ ||
                    !GlobOperator::containsVariables(group))
                {
                    GroupDocIdString id("groupdoc::" + group + ":");
                    _buckets.push_back(BucketId(32, id.getLocation()));
                    _unknown = false;
                }
            } else if (node.getType() == IdValueNode::GID) {
                const StringValueNode* val(
                        dynamic_cast<const StringValueNode*>(&valnode));

                vespalib::string gid(val->getValue());
                if (op == FunctionOperator::EQ ||
                    !GlobOperator::containsVariables(gid))
                {
                    BucketId bid = document::GlobalId::parse(gid).convertToBucketId();
                    _buckets.push_back(BucketId(32, bid.getRawId()));
                    _unknown = false;
                }
            } else if (node.getType() == IdValueNode::BUCKET) {
                const IntegerValueNode* val(
                        dynamic_cast<const IntegerValueNode*>(&valnode));
                if (!val) return;

                BucketId bid(val->getValue());
                if (!bid.getUsedBits()) {
                    bid.setUsedBits(32);
                }
                _buckets.push_back(bid);
                _unknown = false;
            }
        }

        void compare(const select::SearchColumnValueNode& node,
                     const select::ValueNode& valnode,
                     const select::Operator& op) {
            if (op == FunctionOperator::EQ || op == document::select::GlobOperator::GLOB) {
                int bucketCount = 1 << 16;
                const IntegerValueNode* val(
                        dynamic_cast<const IntegerValueNode*>(&valnode));

                int64_t rval = val->getValue();

                for (int i = 0; i < bucketCount; i++) {
                    int64_t column = node.getValue(BucketId(16, i));
                    if (column == rval) {
                        _buckets.push_back(BucketId(16, i));
                    }
                }

                _unknown = false;
            }
        }

        void visitComparison(const document::select::Compare& node) override {
            if (node.getOperator() != document::select::FunctionOperator::EQ &&
                node.getOperator() != document::select::GlobOperator::GLOB)
            {
                return;
            }
            const IdValueNode* lid(dynamic_cast<const IdValueNode*>(
                        &node.getLeft()));
            const SearchColumnValueNode* sc(dynamic_cast<const SearchColumnValueNode*>(
                                                    &node.getLeft()));
            if (lid) {
                compare(*lid, node.getRight(), node.getOperator());
            } else if (sc) {
                compare(*sc, node.getRight(), node.getOperator());
            } else {
                const IdValueNode* rid(dynamic_cast<const IdValueNode*>(
                            &node.getRight()));
                if (rid) {
                    compare(*rid, node.getLeft(), node.getOperator());
                }
            }
        }

        void visitConstant(const document::select::Constant&) override {}
        void visitInvalidConstant(const document::select::InvalidConstant &) override {}
        void visitDocumentType(const document::select::DocType&) override {}
        void visitArithmeticValueNode(const ArithmeticValueNode &) override {}
        void visitFunctionValueNode(const FunctionValueNode &) override {}
        void visitIdValueNode(const IdValueNode &) override {}
        void visitSearchColumnValueNode(const SearchColumnValueNode &) override {}
        void visitFieldValueNode(const FieldValueNode &) override {}
        void visitFloatValueNode(const FloatValueNode &) override {}
        void visitVariableValueNode(const VariableValueNode &) override {}
        void visitIntegerValueNode(const IntegerValueNode &) override {}
        void visitCurrentTimeValueNode(const CurrentTimeValueNode &) override {}
        void visitStringValueNode(const StringValueNode &) override {}
        void visitNullValueNode(const NullValueNode &) override {}
        void visitInvalidValueNode(const InvalidValueNode &) override {}
    };
//}

/**
 * Creates a bucket selector.
 *
 * @param factory Bucket id factory stored in _factory; it is handed to the
 *                visitor created in select().
 */
BucketSelector::BucketSelector(const document::BucketIdFactory& factory)
    : _factory(factory)
{}

/**
 * Computes the buckets the given selection expression can match.
 *
 * @param expression Root node of the document selection to analyse.
 * @return A copy of the buckets collected by the visitor, or a null pointer
 *         if the visitor ended up with its _unknown flag set, i.e. the
 *         selection could not be narrowed down to a set of buckets.
 */
std::unique_ptr<BucketSelector::BucketVector>
BucketSelector::select(const document::select::Node& expression) const
{
    BucketVisitor visitor(_factory);
    expression.visit(visitor);

    if (visitor._unknown) {
        return std::unique_ptr<BucketVector>();
    }
    return std::unique_ptr<BucketVector>(new BucketVector(visitor._buckets));
}

} // document
